# Import spark libraries
import findspark
findspark.init()
findspark.find()
import pyspark
from pyspark import SparkContext, SparkConf
# import pyspark.sql libraries
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.functions import row_number, monotonically_increasing_id
from pyspark.sql import Window
# Import pyspark.ml libraries
from pyspark.ml.linalg import Vectors
from pyspark.ml.stat import Correlation
from pyspark.ml.feature import VectorAssembler
from pyspark.ml import Pipeline
from pyspark.ml.regression import GeneralizedLinearRegression
from pyspark.ml.feature import StandardScaler
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
# Import pandas and plotly libraries
import pandas as pd
import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt
import plotly.graph_objects as go
import plotly.express as px
# Initiate Spark Context
# Build the configuration once and let the session builder create (or reuse) the
# session. Going through getOrCreate() avoids the "Cannot run multiple
# SparkContexts at once" error that a direct SparkContext(conf=conf) call raises
# when this cell is executed a second time in the same kernel.
conf = pyspark.SparkConf().setAppName('SparkApp').setMaster('local')
spark = SparkSession.builder.config(conf=conf).getOrCreate()
# Expose the context owned by the session instead of constructing a second one
sc = spark.sparkContext
# Get infos about spark session (bare expression: displayed by the notebook)
spark
SparkSession - in-memory
# Import Python operating-system libraries for interacting with the OS
import os
from os.path import isfile, join
# Input data lives in the "data" folder next to the current working directory
loc = os.path.abspath("")
data_loc = f"{loc}/data/"
# Load the CSV: the first row holds the column names and the column types are
# inferred from the values
csv_path = f"{data_loc}owid_covid.csv"
df_init = spark.read.csv(csv_path, header=True, inferSchema=True)
# Display the inferred schema to verify the type of every field
df_init.printSchema()
root |-- iso_code: string (nullable = true) |-- continent: string (nullable = true) |-- location: string (nullable = true) |-- date: string (nullable = true) |-- total_cases: double (nullable = true) |-- new_cases: double (nullable = true) |-- new_cases_smoothed: double (nullable = true) |-- total_deaths: double (nullable = true) |-- new_deaths: double (nullable = true) |-- new_deaths_smoothed: double (nullable = true) |-- total_cases_per_million: double (nullable = true) |-- new_cases_per_million: double (nullable = true) |-- new_cases_smoothed_per_million: double (nullable = true) |-- total_deaths_per_million: double (nullable = true) |-- new_deaths_per_million: double (nullable = true) |-- new_deaths_smoothed_per_million: double (nullable = true) |-- reproduction_rate: double (nullable = true) |-- icu_patients: double (nullable = true) |-- icu_patients_per_million: double (nullable = true) |-- hosp_patients: double (nullable = true) |-- hosp_patients_per_million: double (nullable = true) |-- weekly_icu_admissions: double (nullable = true) |-- weekly_icu_admissions_per_million: double (nullable = true) |-- weekly_hosp_admissions: double (nullable = true) |-- weekly_hosp_admissions_per_million: double (nullable = true) |-- new_tests: double (nullable = true) |-- total_tests: double (nullable = true) |-- total_tests_per_thousand: double (nullable = true) |-- new_tests_per_thousand: double (nullable = true) |-- new_tests_smoothed: double (nullable = true) |-- new_tests_smoothed_per_thousand: double (nullable = true) |-- positive_rate: double (nullable = true) |-- tests_per_case: double (nullable = true) |-- tests_units: string (nullable = true) |-- total_vaccinations: double (nullable = true) |-- people_vaccinated: double (nullable = true) |-- people_fully_vaccinated: double (nullable = true) |-- new_vaccinations: double (nullable = true) |-- new_vaccinations_smoothed: double (nullable = true) |-- total_vaccinations_per_hundred: double (nullable = true) |-- 
people_vaccinated_per_hundred: double (nullable = true) |-- people_fully_vaccinated_per_hundred: double (nullable = true) |-- new_vaccinations_smoothed_per_million: double (nullable = true) |-- stringency_index: double (nullable = true) |-- population: double (nullable = true) |-- population_density: double (nullable = true) |-- median_age: double (nullable = true) |-- aged_65_older: double (nullable = true) |-- aged_70_older: double (nullable = true) |-- gdp_per_capita: double (nullable = true) |-- extreme_poverty: double (nullable = true) |-- cardiovasc_death_rate: double (nullable = true) |-- diabetes_prevalence: double (nullable = true) |-- female_smokers: double (nullable = true) |-- male_smokers: double (nullable = true) |-- handwashing_facilities: double (nullable = true) |-- hospital_beds_per_thousand: double (nullable = true) |-- life_expectancy: double (nullable = true) |-- human_development_index: double (nullable = true)
# Summary statistics (count, mean, stddev, min, max) of the dataframe,
# transposed so that each original column becomes one row
summary_stats = df_init.describe().toPandas()
summary_stats.T
| 0 | 1 | 2 | 3 | 4 | |
|---|---|---|---|---|---|
| summary | count | mean | stddev | min | max |
| iso_code | 85580 | None | None | ABW | ZWE |
| continent | 81451 | None | None | Africa | South America |
| location | 85580 | None | None | Afghanistan | Zimbabwe |
| date | 85580 | None | None | 2020-01-01 | 2021-05-03 |
| total_cases | 83470 | 832778.9957349946 | 5757477.472047493 | 1.0 | 1.52870507E8 |
| new_cases | 83468 | 5835.210092490535 | 36508.86533731523 | -74347.0 | 905992.0 |
| new_cases_smoothed | 82467 | 5814.830038548904 | 35814.991549206636 | -6223.0 | 826374.286 |
| total_deaths | 73790 | 23163.885973709173 | 137024.56622028106 | 1.0 | 3202523.0 |
| new_deaths | 73948 | 139.27223183858928 | 760.1312758706464 | -1918.0 | 17906.0 |
| new_deaths_smoothed | 82467 | 123.41811175379846 | 695.9858349571732 | -232.143 | 14435.143 |
| total_cases_per_million | 83019 | 10162.936744359886 | 19463.958179571233 | 0.001 | 171901.896 |
| new_cases_per_million | 83017 | 74.39997649878887 | 175.60479416028457 | -2153.437 | 8652.658 |
| new_cases_smoothed_per_million | 82021 | 74.5016326550511 | 149.09617539623656 | -276.825 | 2648.773 |
| total_deaths_per_million | 73352 | 226.73812504090714 | 397.8053083164839 | 0.001 | 2877.95 |
| new_deaths_per_million | 73510 | 1.5092371786151573 | 3.977495993528583 | -76.445 | 218.329 |
| new_deaths_smoothed_per_million | 82021 | 1.338711878665185 | 2.9391449349779664 | -10.921 | 63.14 |
| reproduction_rate | 69306 | 1.0182404120855324 | 0.35633672563155433 | -0.01 | 5.77 |
| icu_patients | 8691 | 1087.9978138303993 | 3033.991358320877 | 0.0 | 29990.0 |
| icu_patients_per_million | 8691 | 26.403532965136367 | 27.860613702241253 | 0.0 | 192.642 |
| hosp_patients | 10821 | 4832.752795490251 | 12433.561423445204 | 0.0 | 129637.0 |
| hosp_patients_per_million | 10821 | 173.69404269475964 | 216.330380008436 | 0.0 | 1532.573 |
| weekly_icu_admissions | 790 | 280.36216708860724 | 588.1153416060671 | 0.0 | 4037.019 |
| weekly_icu_admissions_per_million | 790 | 21.10003544303796 | 37.10012126341996 | 0.0 | 279.13 |
| weekly_hosp_admissions | 1298 | 3994.2353412943 | 11634.79089101694 | 0.0 | 116232.0 |
| weekly_hosp_admissions_per_million | 1298 | 115.32590986132524 | 230.24720305086265 | 0.0 | 2656.911 |
| new_tests | 38945 | 44019.99345230453 | 228082.89074859317 | -239172.0 | 3.2022805E7 |
| total_tests | 38652 | 5963484.145788058 | 2.702657004024681E7 | 0.0 | 4.13502739E8 |
| total_tests_per_thousand | 38652 | 227.4760523388175 | 495.3537686855608 | 0.0 | 6233.953 |
| new_tests_per_thousand | 38945 | 1.9372400564899257 | 15.289038589037073 | -23.01 | 2827.217 |
| new_tests_smoothed | 44625 | 41416.30742857143 | 148450.0041091272 | 0.0 | 4594014.0 |
| new_tests_smoothed_per_thousand | 44625 | 1.7755527619047682 | 4.810456516565619 | 0.0 | 405.595 |
| positive_rate | 42904 | 0.08897072534029671 | 0.09759395562535529 | 0.0 | 0.742 |
| tests_per_case | 42311 | 159.4553425823062 | 864.9287640966207 | 1.3 | 44258.7 |
| tests_units | 46079 | None | None | people tested | units unclear |
| total_vaccinations | 9551 | 1.4957485974452937E7 | 6.869741373731461E7 | 0.0 | 1.162067962E9 |
| people_vaccinated | 8910 | 9225179.1661055 | 3.9027792779540956E7 | 0.0 | 6.04516098E8 |
| people_fully_vaccinated | 6576 | 4772619.979622871 | 1.8868453614355136E7 | 1.0 | 2.75959041E8 |
| new_vaccinations | 8110 | 424912.7676942047 | 1698118.626609131 | 0.0 | 2.4728855E7 |
| new_vaccinations_smoothed | 15322 | 226123.2664143062 | 1159589.1827649393 | 0.0 | 2.0323434E7 |
| total_vaccinations_per_hundred | 9551 | 13.58571144382781 | 21.893459977692924 | 0.0 | 211.08 |
| people_vaccinated_per_hundred | 8910 | 9.51923007856343 | 13.923543885901049 | 0.0 | 111.32 |
| people_fully_vaccinated_per_hundred | 6576 | 5.181970802919693 | 9.721152822921756 | 0.0 | 99.76 |
| new_vaccinations_smoothed_per_million | 15322 | 2784.0497324109124 | 4620.116653779478 | 0.0 | 118759.0 |
| stringency_index | 72673 | 58.71260275480682 | 21.648871289508666 | 0.0 | 100.0 |
| population | 85029 | 1.283663044966776E8 | 6.902714046105045E8 | 809.0 | 7.794798729E9 |
| population_density | 79659 | 349.79385832101934 | 1703.2164207822327 | 0.137 | 20546.766 |
| median_age | 77078 | 30.521583331174902 | 9.115053092093216 | 15.1 | 48.2 |
| aged_65_older | 76198 | 8.772912753614143 | 6.223711873362256 | 1.144 | 27.049 |
| aged_70_older | 76646 | 5.556767946141541 | 4.248685782933569 | 0.526 | 18.493 |
| gdp_per_capita | 77418 | 19139.031322351304 | 19826.54402373803 | 661.24 | 116935.6 |
| extreme_poverty | 52700 | 13.35125616698124 | 19.943899746923773 | 0.1 | 77.6 |
| cardiovasc_death_rate | 78005 | 257.8088030126566 | 118.77751200084232 | 79.37 | 724.417 |
| diabetes_prevalence | 79158 | 7.821495237373721 | 3.9780571783350993 | 0.99 | 30.53 |
| female_smokers | 61115 | 10.519806561401039 | 10.402752297957854 | 0.1 | 44.0 |
| male_smokers | 60214 | 32.657161872646405 | 13.475414080788601 | 7.7 | 78.1 |
| handwashing_facilities | 39197 | 50.91309169068946 | 31.763061733684147 | 1.188 | 98.999 |
| hospital_beds_per_thousand | 71182 | 3.0294416846951417 | 2.4634261515052818 | 0.1 | 13.8 |
| life_expectancy | 81224 | 73.16502240718631 | 7.5497456561661656 | 53.28 | 86.75 |
| human_development_index | 77890 | 0.7271036590064779 | 0.15005860466163015 | 0.394 | 0.957 |
# Plot the absolute correlation of the target column new_deaths with the other
# numeric columns of the dataframe, restricted to one location
location_of_analysis = "Italy"
columns_to_drop = ['weekly_icu_admissions', 'weekly_icu_admissions_per_million', 'population', 'extreme_poverty', 'handwashing_facilities']
# Missing values are replaced with 0 before the unwanted columns are removed
df_init_c = df_init.na.fill(0).drop(*columns_to_drop)
# Filter on the dataframe's own column rather than on a column object that
# belongs to the parent df_init
df_init_c = df_init_c.where(df_init_c.location == location_of_analysis).orderBy("date")
df_init_c = df_init_c.toPandas()
# Correlate numeric columns only: the string columns (iso_code, continent,
# location, date, tests_units) cannot be correlated and make DataFrame.corr()
# raise on pandas >= 2.0
corrMatrix = df_init_c.select_dtypes(include="number").corr().abs().round(4)
# Keep only the column that relates every feature to the target
corrMatrix = corrMatrix[['new_deaths']]
plt.subplots(figsize=(10,15))
sns.heatmap(corrMatrix, annot=True, cmap="Blues")
<AxesSubplot:>
# Scatter plots, each with an OLS trendline, that visualise how new_deaths
# relates to new_cases, icu_patients and hosp_patients
scatter_options = dict(x="new_deaths", trendline="ols")
fig1 = px.scatter(df_init_c, y="new_cases", color_discrete_sequence=['#7d87ff'], **scatter_options)
fig2 = px.scatter(df_init_c, y="icu_patients", color_discrete_sequence=['#1b2bf4'], **scatter_options)
fig3 = px.scatter(df_init_c, y="hosp_patients", color_discrete_sequence=['#1b2bf4'], **scatter_options)
# Render the three figures in the same order they were built
for figure in (fig1, fig2, fig3):
    figure.show()
# df contains only rows before vaccinations
# Missing values are filled with 0 first, so days without a reported
# total_vaccinations value count as "no vaccinations yet"
df_init = df_init.na.fill(0)
# Filter on location while that column is still part of the dataframe, then
# keep only the date (parsed to a real date type), the features and the target
df = df_init.filter(
    (col("total_vaccinations") == 0) & (col("location") == location_of_analysis)
).select(
    to_date(col("date"), "yyyy-MM-dd").alias("date"),
    "new_cases", "hosp_patients", "icu_patients", "new_deaths",
).orderBy("date")
# Define and add an ordered, zero-based index column in dataframe.
# The window is ordered by date explicitly: monotonically_increasing_id()
# reflects the physical partition layout, so ordering by it only matches the
# chronological order as long as the preceding sort happens to be preserved.
df = df.withColumn(
    "index",
    row_number().over(Window.orderBy("date")) - 1
)
# Show dataframe and schema
df.show()
df.printSchema()
# Convert dataframe to pandas
pandasDF = df.toPandas()
print("Total Rows 1° dataframe -before vaccinations-: %.3f" % len(pandasDF))
# This date array is used later to map index values back to dates (before vaccinations)
date_column = pandasDF.loc[:,'date']
x_date = date_column.values
# Plot dataframe
fig = px.line(pandasDF, x="date", y="new_deaths")
fig.show()
+----------+---------+-------------+------------+----------+-----+ | date|new_cases|hosp_patients|icu_patients|new_deaths|index| +----------+---------+-------------+------------+----------+-----+ |2020-01-31| 2.0| 0.0| 0.0| 0.0| 0| |2020-02-01| 0.0| 0.0| 0.0| 0.0| 1| |2020-02-02| 0.0| 0.0| 0.0| 0.0| 2| |2020-02-03| 0.0| 0.0| 0.0| 0.0| 3| |2020-02-04| 0.0| 0.0| 0.0| 0.0| 4| |2020-02-05| 0.0| 0.0| 0.0| 0.0| 5| |2020-02-06| 0.0| 0.0| 0.0| 0.0| 6| |2020-02-07| 1.0| 0.0| 0.0| 0.0| 7| |2020-02-08| 0.0| 0.0| 0.0| 0.0| 8| |2020-02-09| 0.0| 0.0| 0.0| 0.0| 9| |2020-02-10| 0.0| 0.0| 0.0| 0.0| 10| |2020-02-11| 0.0| 0.0| 0.0| 0.0| 11| |2020-02-12| 0.0| 0.0| 0.0| 0.0| 12| |2020-02-13| 0.0| 0.0| 0.0| 0.0| 13| |2020-02-14| 0.0| 0.0| 0.0| 0.0| 14| |2020-02-15| 0.0| 0.0| 0.0| 0.0| 15| |2020-02-16| 0.0| 0.0| 0.0| 0.0| 16| |2020-02-17| 0.0| 0.0| 0.0| 0.0| 17| |2020-02-18| 0.0| 0.0| 0.0| 0.0| 18| |2020-02-19| 0.0| 0.0| 0.0| 0.0| 19| +----------+---------+-------------+------------+----------+-----+ only showing top 20 rows root |-- date: date (nullable = true) |-- new_cases: double (nullable = false) |-- hosp_patients: double (nullable = false) |-- icu_patients: double (nullable = false) |-- new_deaths: double (nullable = false) |-- index: integer (nullable = true) Total Rows 1° dataframe -before vaccinations-: 331.000
# df2 contains all rows after the start of vaccinations
# Filter on location while that column is still part of the dataframe, then
# keep only the date (parsed to a real date type), the features, the target
# and the vaccination counter
df2 = df_init.filter(
    (col("total_vaccinations") > 0) & (col("location") == location_of_analysis)
).select(
    to_date(col("date"), "yyyy-MM-dd").alias("date"),
    "new_cases", "hosp_patients", "icu_patients", "new_deaths", "total_vaccinations",
).orderBy("date")
# Define and add an ordered, zero-based index column in dataframe.
# The window is ordered by date explicitly: monotonically_increasing_id()
# reflects the physical partition layout, so ordering by it only matches the
# chronological order as long as the preceding sort happens to be preserved.
df2 = df2.withColumn(
    "index",
    row_number().over(Window.orderBy("date")) - 1
)
# Show dataframe and schema
df2.show()
df2.printSchema()
# Convert dataframe to pandas
pandasDF2 = df2.toPandas()
print("Total Rows 2° dataframe -after vaccinationsi- : %.3f" % len(pandasDF2))
# This date array is used later to map index values back to dates (after vaccinations)
date_column2 = pandasDF2.loc[:,'date']
x_date2 = date_column2.values
# Plot dataframe new deaths
fig2 = px.line(pandasDF2, x="date", y="new_deaths")
fig2.show()
# Plot dataframe vaccinations
fig3 = px.line(pandasDF2, x="date", y="total_vaccinations")
fig3.show()
+----------+---------+-------------+------------+----------+------------------+-----+ | date|new_cases|hosp_patients|icu_patients|new_deaths|total_vaccinations|index| +----------+---------+-------------+------------+----------+------------------+-----+ |2020-12-27| 8937.0| 26151.0| 2580.0| 305.0| 7175.0| 0| |2020-12-28| 8581.0| 26497.0| 2565.0| 445.0| 8600.0| 1| |2020-12-29| 11210.0| 26211.0| 2549.0| 659.0| 9609.0| 2| |2020-12-30| 16202.0| 26094.0| 2528.0| 575.0| 14337.0| 3| |2020-12-31| 23477.0| 25706.0| 2555.0| 555.0| 39818.0| 4| |2021-01-01| 22210.0| 25375.0| 2553.0| 462.0| 50877.0| 5| |2021-01-02| 11825.0| 25517.0| 2569.0| 364.0| 89400.0| 6| |2021-01-03| 14245.0| 25658.0| 2583.0| 347.0| 124565.0| 7| |2021-01-04| 10798.0| 25896.0| 2579.0| 348.0| 193333.0| 8| |2021-01-05| 15375.0| 25964.0| 2569.0| 649.0| 273130.0| 9| |2021-01-06| 20326.0| 25745.0| 2571.0| 548.0| 338319.0| 10| |2021-01-07| 18416.0| 25878.0| 2587.0| 414.0| 430571.0| 11| |2021-01-08| 17529.0| 25900.0| 2587.0| 620.0| 526196.0| 12| |2021-01-09| 19976.0| 25853.0| 2593.0| 483.0| 613188.0| 13| |2021-01-10| 18625.0| 26042.0| 2615.0| 361.0| 673555.0| 14| |2021-01-11| 12530.0| 26245.0| 2642.0| 448.0| 754532.0| 15| |2021-01-12| 14242.0| 26348.0| 2636.0| 616.0| 836594.0| 16| |2021-01-13| 15773.0| 26104.0| 2579.0| 507.0| 931146.0| 17| |2021-01-14| 17243.0| 25667.0| 2557.0| 522.0| 1024900.0| 18| |2021-01-15| 16144.0| 25363.0| 2522.0| 477.0| 1114574.0| 19| +----------+---------+-------------+------------+----------+------------------+-----+ only showing top 20 rows root |-- date: date (nullable = true) |-- new_cases: double (nullable = false) |-- hosp_patients: double (nullable = false) |-- icu_patients: double (nullable = false) |-- new_deaths: double (nullable = false) |-- total_vaccinations: double (nullable = false) |-- index: integer (nullable = true) Total Rows 2° dataframe -after vaccinationsi- : 127.000
# Columns used as model inputs
features = ["new_cases","icu_patients","hosp_patients","index"]
# Both model inputs share one layout: new_deaths renamed to "label", followed by the features
label_and_features = [col("new_deaths").alias("label"), *features]
# Pre-vaccination rows: later split into training and test data
lr_data = df.select(*label_and_features)
# Post-vaccination rows: used only for testing
test2 = df2.select(*label_and_features)
# Print schema and a sample of rows, first for the pre-vaccination input and
# then for the post-vaccination input
for frame in (lr_data, test2):
    frame.printSchema()
    frame.show()
root |-- label: double (nullable = false) |-- new_cases: double (nullable = false) |-- icu_patients: double (nullable = false) |-- hosp_patients: double (nullable = false) |-- index: integer (nullable = true) +-----+---------+------------+-------------+-----+ |label|new_cases|icu_patients|hosp_patients|index| +-----+---------+------------+-------------+-----+ | 0.0| 2.0| 0.0| 0.0| 0| | 0.0| 0.0| 0.0| 0.0| 1| | 0.0| 0.0| 0.0| 0.0| 2| | 0.0| 0.0| 0.0| 0.0| 3| | 0.0| 0.0| 0.0| 0.0| 4| | 0.0| 0.0| 0.0| 0.0| 5| | 0.0| 0.0| 0.0| 0.0| 6| | 0.0| 1.0| 0.0| 0.0| 7| | 0.0| 0.0| 0.0| 0.0| 8| | 0.0| 0.0| 0.0| 0.0| 9| | 0.0| 0.0| 0.0| 0.0| 10| | 0.0| 0.0| 0.0| 0.0| 11| | 0.0| 0.0| 0.0| 0.0| 12| | 0.0| 0.0| 0.0| 0.0| 13| | 0.0| 0.0| 0.0| 0.0| 14| | 0.0| 0.0| 0.0| 0.0| 15| | 0.0| 0.0| 0.0| 0.0| 16| | 0.0| 0.0| 0.0| 0.0| 17| | 0.0| 0.0| 0.0| 0.0| 18| | 0.0| 0.0| 0.0| 0.0| 19| +-----+---------+------------+-------------+-----+ only showing top 20 rows root |-- label: double (nullable = false) |-- new_cases: double (nullable = false) |-- icu_patients: double (nullable = false) |-- hosp_patients: double (nullable = false) |-- index: integer (nullable = true) +-----+---------+------------+-------------+-----+ |label|new_cases|icu_patients|hosp_patients|index| +-----+---------+------------+-------------+-----+ |305.0| 8937.0| 2580.0| 26151.0| 0| |445.0| 8581.0| 2565.0| 26497.0| 1| |659.0| 11210.0| 2549.0| 26211.0| 2| |575.0| 16202.0| 2528.0| 26094.0| 3| |555.0| 23477.0| 2555.0| 25706.0| 4| |462.0| 22210.0| 2553.0| 25375.0| 5| |364.0| 11825.0| 2569.0| 25517.0| 6| |347.0| 14245.0| 2583.0| 25658.0| 7| |348.0| 10798.0| 2579.0| 25896.0| 8| |649.0| 15375.0| 2569.0| 25964.0| 9| |548.0| 20326.0| 2571.0| 25745.0| 10| |414.0| 18416.0| 2587.0| 25878.0| 11| |620.0| 17529.0| 2587.0| 25900.0| 12| |483.0| 19976.0| 2593.0| 25853.0| 13| |361.0| 18625.0| 2615.0| 26042.0| 14| |448.0| 12530.0| 2642.0| 26245.0| 15| |616.0| 14242.0| 2636.0| 26348.0| 16| |507.0| 15773.0| 2579.0| 26104.0| 17| |522.0| 17243.0| 
2557.0| 25667.0| 18| |477.0| 16144.0| 2522.0| 25363.0| 19| +-----+---------+------------+-------------+-----+ only showing top 20 rows
# Fixed seed shared by the random split and the cross-validation folds, so that
# repeated runs on the same input data produce comparable results
RANDOM_SEED = 42
# Divide data in 70% for training and 30% for testing (before vaccinations)
(training, test) = lr_data.randomSplit([.7, .3], seed=RANDOM_SEED)
# Assemble the feature columns into a single (unscaled) vector column
vectorAssembler = VectorAssembler(inputCols=features, outputCol="unscaled_features")
# Scale the features array by normalizing each feature to have unit standard deviation.
standardScaler = StandardScaler(inputCol="unscaled_features", outputCol="features")
# Define GLR with default params
lr = GeneralizedLinearRegression()
# Define stages for pipeline: assemble -> scale -> regress
stages = [vectorAssembler, standardScaler, lr]
# Define the pipeline with stages
pipeline = Pipeline(stages=stages)
# We use a ParamGridBuilder to construct a grid of parameters to search over
# (here: two candidate values for the regularization parameter).
param_grid = ParamGridBuilder().addGrid(lr.regParam, [1, 0.001]).build()
# We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance.
# This will allow us to jointly choose parameters for all Pipeline stages.
# A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator.
# Note that the evaluator here is a RegressionEvaluator and its default metric is rmse
# where rmse stands for Root Mean Square Error
cv = CrossValidator(
    estimator=pipeline,
    estimatorParamMaps=param_grid,
    evaluator=RegressionEvaluator(),
    numFolds=10,
    seed=RANDOM_SEED,
)
# NOTE: ALL THIS CODE IS FOR DATA BEFORE VACCINATIONS (TRAINING+TEST)
# Run cross-validation on the training split and keep the best parameter set
model = cv.fit(training)
# Predict on the test split with the fitted model, then sort the rows by index
prediction = model.transform(test).orderBy("index")
prediction.show()
# Bring the predictions into pandas and report the number of predicted rows
pandasDF = prediction.toPandas()
print("Total Rows -before vaccinations- : %.3f" % len(pandasDF))
# Index values of the predicted rows; they are mapped to dates for the x axis
index_column = pandasDF['index']
x_array = index_column.values
# Look up the date that corresponds to every predicted row index
list_x = [x_date[position] for position in x_array]
# Observed values (label) and model output (prediction) for the y axis
label_column = pandasDF['label']
y1_array = label_column.values
index_column = pandasDF['prediction']
y2_array = index_column.values
# Plot original data and predicted data as two line traces over the same dates
fig = go.Figure()
for trace_name, trace_values in (
    ('new_covid-19_deaths', y1_array),
    ('new_covid-19_deaths_predicted', y2_array),
):
    fig.add_trace(go.Scatter(x=list_x, y=trace_values, mode='lines', name=trace_name))
fig.show()
+-----+---------+------------+-------------+-----+--------------------+--------------------+------------------+ |label|new_cases|icu_patients|hosp_patients|index| unscaled_features| features| prediction| +-----+---------+------------+-------------+-----+--------------------+--------------------+------------------+ | 0.0| 0.0| 0.0| 0.0| 3| (4,[3],[3.0])|(4,[3],[0.0315751...|-9.104579437450319| | 0.0| 0.0| 0.0| 0.0| 6| (4,[3],[6.0])|(4,[3],[0.0631503...| -9.1319248065238| | 0.0| 0.0| 0.0| 0.0| 12| (4,[3],[12.0])|(4,[3],[0.1263006...|-9.186615544670763| | 0.0| 0.0| 0.0| 0.0| 13| (4,[3],[13.0])|(4,[3],[0.1368256...|-9.195730667695257| | 0.0| 0.0| 0.0| 0.0| 14| (4,[3],[14.0])|(4,[3],[0.1473507...| -9.20484579071975| | 0.0| 0.0| 0.0| 0.0| 16| (4,[3],[16.0])|(4,[3],[0.1684008...|-9.223076036768738| | 2.0| 131.0| 36.0| 164.0| 26|[131.0,36.0,164.0...|[0.01346131007941...|-3.114129154261537| | 5.0| 202.0| 56.0| 304.0| 27|[202.0,56.0,304.0...|[0.02075713462627...| 0.50461799780002| | 18.0| 342.0| 166.0| 908.0| 31|[342.0,166.0,908....|[0.03514326753557...|20.152549040516014| | 36.0| 1247.0| 567.0| 3218.0| 36|[1247.0,567.0,321...|[0.12813934098496...| 91.72153314313375| |196.0| 2313.0| 1028.0| 6866.0| 40|[2313.0,1028.0,68...|[0.23767946728004...|177.65195587290464| |175.0| 3497.0| 1518.0| 9890.0| 43|[3497.0,1518.0,98...|[0.35934504845581...|265.77085230441935| |368.0| 3590.0| 1672.0| 11335.0| 44|[3590.0,1672.0,11...|[0.36890155103128...| 295.6799787889233| |627.0| 5986.0| 2655.0| 18675.0| 49|[5986.0,2655.0,18...|[0.61510993996469...| 477.1504384412424| |793.0| 6557.0| 2857.0| 20565.0| 50|[6557.0,2857.0,20...|[0.67378481061618...| 515.7456205332933| |651.0| 5560.0| 3009.0| 22855.0| 51|[5560.0,3009.0,22...|[0.57133499268354...| 549.9639294414459| |889.0| 5974.0| 3856.0| 30532.0| 57|[5974.0,3856.0,30...|[0.61387684285818...| 713.5930381028747| |756.0| 5217.0| 3906.0| 31292.0| 58|[5217.0,3906.0,31...|[0.53608896705576...| 725.4594001811868| |812.0| 4050.0| 3981.0| 31776.0| 
59|[4050.0,3981.0,31...|[0.41617027344754...| 740.8794296036061| |766.0| 4585.0| 4068.0| 32809.0| 63|[4585.0,4068.0,32...|[0.47114585277950...| 757.881319629163| +-----+---------+------------+-------------+-----+--------------------+--------------------+------------------+ only showing top 20 rows Total Rows -before vaccinations- : 104.000
# Create the evaluator used to score the pre-vaccination predictions.
# It is named "evaluator" so that the Python builtin eval() is not shadowed.
evaluator = RegressionEvaluator(labelCol="label", predictionCol="prediction", metricName="rmse")
# Root Mean Square Error
rmse = evaluator.evaluate(prediction)
print("RMSE: %.3f" % rmse)
# Mean Square Error
mse = evaluator.evaluate(prediction, {evaluator.metricName: "mse"})
print("MSE: %.3f" % mse)
# Mean Absolute Error
mae = evaluator.evaluate(prediction, {evaluator.metricName: "mae"})
print("MAE: %.3f" % mae)
# R2 - coefficient of determination (the closer to 1, the better the fit;
# a value of exactly 1 is suspicious and usually points to an error in the model)
R2 = evaluator.evaluate(prediction, {evaluator.metricName: "r2"})
print("R2: %.3f" %R2)
RMSE: 75.117 MSE: 5642.502 MAE: 46.026 R2: 0.929
# NOTE: ALL THIS CODE IS FOR DATA AFTER VACCINATIONS (TESTING MODEL)
# Apply the fitted model to the post-vaccination rows, then sort the rows by index
prediction2 = model.transform(test2).orderBy("index")
prediction2.show()
# Bring the predictions into pandas and report the number of predicted rows
pandasDF2 = prediction2.toPandas()
print("Total Rows -after vaccinations- : %.3f" % len(pandasDF2))
# Index values of the predicted rows; they are mapped to dates for the x axis
index_column2 = pandasDF2['index']
x_array2 = index_column2.values
# Look up the date that corresponds to every predicted row index
list_x2 = [x_date2[position] for position in x_array2]
# Observed values (label) and model output (prediction) for the y axis
label_column2 = pandasDF2['label']
y1_array2 = label_column2.values
index_column2 = pandasDF2['prediction']
y2_array2 = index_column2.values
# Plot original data and predicted data as two line traces over the same dates
fig2 = go.Figure()
for trace_name, trace_values in (
    ('new_covid-19_deaths -v', y1_array2),
    ('new_covid-19_deaths_predicted -v', y2_array2),
):
    fig2.add_trace(go.Scatter(x=list_x2, y=trace_values, mode='lines', name=trace_name))
fig2.show()
+-----+---------+------------+-------------+-----+--------------------+--------------------+------------------+ |label|new_cases|icu_patients|hosp_patients|index| unscaled_features| features| prediction| +-----+---------+------------+-------------+-----+--------------------+--------------------+------------------+ |305.0| 8937.0| 2580.0| 26151.0| 0|[8937.0,2580.0,26...|[0.91834907007425...| 489.1636274603039| |445.0| 8581.0| 2565.0| 26497.0| 1|[8581.0,2565.0,26...|[0.88176718924775...| 488.520004660835| |659.0| 11210.0| 2549.0| 26211.0| 2|[11210.0,2549.0,2...|[1.15191821366592...| 481.2611337111114| |575.0| 16202.0| 2528.0| 26094.0| 3|[16202.0,2528.0,2...|[1.66488660997460...| 470.5672575061037| |555.0| 23477.0| 2555.0| 25706.0| 4|[23477.0,2555.0,2...|[2.41245173079705...|463.42089554850156| |462.0| 22210.0| 2553.0| 25375.0| 5|[22210.0,2553.0,2...|[2.28225722796790...| 463.613284543112| |364.0| 11825.0| 2569.0| 25517.0| 6|[11825.0,2569.0,2...|[1.21511444037462...|481.01619481350303| |347.0| 14245.0| 2583.0| 25658.0| 7|[14245.0,2583.0,2...|[1.46378902352106...| 480.4371779883596| |348.0| 10798.0| 2579.0| 25896.0| 8|[10798.0,2579.0,2...|[1.10958187967570...|485.42047538608466| |649.0| 15375.0| 2569.0| 25964.0| 9|[15375.0,2569.0,2...|[1.57990566771753...|477.74768647731327| |548.0| 20326.0| 2571.0| 25745.0| 10|[20326.0,2571.0,2...|[2.08866098224563...|470.42175656608674| |414.0| 18416.0| 2587.0| 25878.0| 11|[18416.0,2587.0,2...|[1.89239302612593...|476.10476024678616| |620.0| 17529.0| 2587.0| 25900.0| 12|[17529.0,2587.0,2...|[1.80124659833630...| 477.4001842255914| |483.0| 19976.0| 2593.0| 25853.0| 13|[19976.0,2593.0,2...|[2.05269564997239...|474.80516664017307| |361.0| 18625.0| 2615.0| 26042.0| 14|[18625.0,2615.0,2...|[1.91386946739766...| 480.8868740950102| |448.0| 12530.0| 2642.0| 26245.0| 15|[12530.0,2642.0,2...|[1.28755889538216...| 494.3640530399432| |616.0| 14242.0| 2636.0| 26348.0| 16|[14242.0,2636.0,2...|[1.46348074924443...| 491.4130130923759| |507.0| 
15773.0| 2579.0| 26104.0| 17|[15773.0,2579.0,2...|[1.62080338841682...|479.24786009378516| |522.0| 17243.0| 2557.0| 25667.0| 18|[17243.0,2557.0,2...|[1.77185778396445...| 472.0661108180667| |477.0| 16144.0| 2522.0| 25363.0| 19|[16144.0,2522.0,2...|[1.65892664062646...|466.83404063606395| +-----+---------+------------+-------------+-----+--------------------+--------------------+------------------+ only showing top 20 rows Total Rows -after vaccinations- : 127.000
# Create the evaluator used to score the post-vaccination predictions.
# It is named "evaluator" so that the Python builtin eval() is not shadowed.
evaluator = RegressionEvaluator(labelCol="label", predictionCol="prediction", metricName="rmse")
# Root Mean Square Error
rmse = evaluator.evaluate(prediction2)
print("RMSE: %.3f" % rmse)
# Mean Square Error
mse = evaluator.evaluate(prediction2, {evaluator.metricName: "mse"})
print("MSE: %.3f" % mse)
# Mean Absolute Error
mae = evaluator.evaluate(prediction2, {evaluator.metricName: "mae"})
print("MAE: %.3f" % mae)
# R2 - coefficient of determination (the closer to 1, the better the fit).
# It must be computed on prediction2 like the other metrics of this section;
# scoring the pre-vaccination "prediction" dataframe here would simply repeat
# the R2 of the earlier evaluation.
R2 = evaluator.evaluate(prediction2, {evaluator.metricName: "r2"})
print("R2: %.3f" %R2)
RMSE: 180.718 MSE: 32658.934 MAE: 150.669 R2: 0.929